import vanna
import pandas as pd
import mysql.connector
from vanna.remote import VannaDefault


def run_sql(sql: str) -> pd.DataFrame:
    """Run ``sql`` against the local MySQL ``student`` database.

    A new connection is opened for every call and is always closed before
    the function returns, including when the query fails, so repeated calls
    do not accumulate open connections.

    Args:
        sql: The SQL statement to execute. It must produce a result set,
            because all rows are fetched after execution.

    Returns:
        A DataFrame holding every fetched row, with columns named after the
        cursor's ``column_names``.

    Raises:
        Any exception raised while connecting, executing or fetching is
        propagated unchanged to the caller.
    """
    # NOTE(review): connection credentials are hard-coded here; consider
    # loading them from configuration or the environment instead.
    cnx = mysql.connector.connect(
        user='root',
        password='123456',
        host='localhost',
        database='student',
    )
    try:
        cursor = cnx.cursor()
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
            # Read the column names before the cursor is closed.
            columns = cursor.column_names
        finally:
            cursor.close()
    finally:
        # Release the connection whether or not the query succeeded.
        cnx.close()
    return pd.DataFrame(rows, columns=columns)


# --- Vanna client setup ------------------------------------------------------
# NOTE(review): the API key is committed in plain text; consider rotating it
# and loading it from the environment or a secrets store instead.
api_key = '56651f10131246a6b634d7f5d5564f70'
vanna_model_name = 'ansWhite123'
vn = VannaDefault(model=vanna_model_name, api_key=api_key)
# Route SQL execution through the local run_sql() helper defined in this file.
vn.run_sql = run_sql
# Presumably tells Vanna that a SQL runner has been attached - TODO confirm
# against the Vanna documentation.
vn.run_sql_is_set = True

# --- One-off training (currently disabled) -----------------------------------
# The commented-out call below would train the model on the DDL of the
# `student_info` table (id, name, classId, hometown and four score columns).
# vn.train(ddl="""
# DROP TABLE IF EXISTS `student_info`;
# CREATE TABLE `student_info`  (
#   `id` int NULL DEFAULT NULL,
#   `name` varchar(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL,
#   `classId` varchar(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL,
#   `hometown` varchar(20) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NULL DEFAULT NULL,
#   `chineseScore` int NULL DEFAULT NULL,
#   `mathScore` int NULL DEFAULT NULL,
#   `englishScore` int NULL DEFAULT NULL,
#   `totalScore` int NULL DEFAULT NULL
# ) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci ROW_FORMAT = Dynamic;
#
# """)

# Disabled question/SQL training pairs. English glosses of the questions:
#   1) "Count the number of people in each ethnic group."
#   2) "Which department does Zhang San work in?"
# NOTE(review): these pairs query `customer` and `employee` tables, which do
# not appear in the `student_info` DDL above - confirm which schema is intended.
# vn.train(question='统计不同民族数量？', sql='SELECT nation, COUNT(*) as count FROM customer GROUP BY nation ORDER BY count DESC;')
# vn.train(question='张三的工作部门是什么？', sql='SELECT department FROM employee where id_card=(SELECT id_card FROM customer WHERE name="张三");')

# --- Smoke test ----------------------------------------------------------------
# NOTE(review): "1111" looks like a leftover debug marker.
print("1111")
# Ask the model to generate SQL for two sample questions:
#   "What is Li Si's job position?"
#   "What is the phone number of the person whose job title is supervisor?"
# NOTE(review): the returned values are discarded here, so the generated SQL is
# not used by this script.
vn.generate_sql('李四的岗位是什么？')
vn.generate_sql('职务是主管的电话号码是什么？')
# Disabled alternative: vn.ask() on the first question ("What is Li Si's job position?").
#vn.ask('李四的岗位是什么？')

# --- Web UI ----------------------------------------------------------------------
# NOTE(review): this import sits mid-file; by convention it belongs with the
# other imports at the top of the module.
from vanna.flask import VannaFlaskApp
# allow_llm_to_see_data=True presumably permits query results to be sent to the
# LLM - TODO confirm this is acceptable for the data in this database.
app = VannaFlaskApp(vn, allow_llm_to_see_data=True)
# Start the Flask application at import/run time (no __main__ guard).
app.run()
